{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "## =====================================================================================\n",
    "## This is a temporarly fix for the freezing and the cuda issues. You can add this\n",
    "## utility script instead of kaggle_l5kit until Kaggle resolve these issues.\n",
    "## \n",
    "## You will be able to train and submit your results, but not all the functionality of\n",
    "## l5kit will work properly.\n",
    "\n",
    "## More details here:\n",
    "## https://www.kaggle.com/c/lyft-motion-prediction-autonomous-vehicles/discussion/177125\n",
    "\n",
    "## this script transports l5kit and dependencies\n",
    "os.system('pip install --target=/kaggle/working pymap3d==2.1.0')\n",
    "os.system('pip install --target=/kaggle/working protobuf==3.12.2')\n",
    "os.system('pip install --target=/kaggle/working transforms3d')\n",
    "os.system('pip install --target=/kaggle/working zarr')\n",
    "os.system('pip install --target=/kaggle/working ptable')\n",
    "\n",
    "os.system('pip install --no-dependencies --target=/kaggle/working l5kit')\n",
    "# os.system('pip install --target=/kaggle/working timm')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "!pip install timm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "_cell_guid": "79c7e3d0-c299-4dcb-8224-4455121ee9b0",
    "_uuid": "d629ff2d2480ee46fbb7e2d37f6b5fab8052498a"
   },
   "outputs": [],
   "source": [
    "# import packages\n",
    "import os, gc\n",
    "import zarr\n",
    "import numpy as np \n",
    "import pandas as pd \n",
    "from tqdm import tqdm\n",
    "from typing import Dict, List, Callable\n",
    "from collections import Counter\n",
    "from prettytable import PrettyTable\n",
    "from collections import OrderedDict\n",
    "import math\n",
    "import pickle\n",
    "from IPython.display import FileLink\n",
    "\n",
    "#level5 toolkit\n",
    "from l5kit.data import PERCEPTION_LABELS\n",
    "from l5kit.dataset import EgoDataset, AgentDataset\n",
    "from l5kit.data import ChunkedDataset, LocalDataManager\n",
    "\n",
    "# level5 toolkit \n",
    "from l5kit.configs import load_config_data\n",
    "from l5kit.geometry import transform_points\n",
    "from l5kit.rasterization import build_rasterizer\n",
    "from l5kit.visualization import draw_trajectory, draw_reference_trajectory, TARGET_POINTS_COLOR, PREDICTED_POINTS_COLOR, write_gif\n",
    "from l5kit.evaluation import write_pred_csv, compute_metrics_csv, read_gt_csv, create_chopped_dataset, export_zarr_to_csv, write_gt_csv\n",
    "from l5kit.evaluation.metrics import neg_multi_log_likelihood, time_displace\n",
    "\n",
    "# visualization\n",
    "import seaborn as sns\n",
    "import matplotlib.pyplot as plt\n",
    "from matplotlib import animation\n",
    "from colorama import Fore, Back, Style\n",
    "\n",
    "# deep learning\n",
    "import torch\n",
    "from torch import nn, optim, Tensor\n",
    "from torch.utils.data import DataLoader\n",
    "from torchvision.models.resnet import resnet18, resnet50, resnet34\n",
    "from torchvision.models.mobilenet import mobilenet_v2\n",
    "from torch.nn import functional as F\n",
    "import timm\n",
    "\n",
    "# ensemble learning\n",
    "from sklearn.cluster import KMeans\n",
    "\n",
    "# check files in directory\n",
    "print((os.listdir('../input/lyft-motion-prediction-autonomous-vehicles/')))\n",
    "\n",
    "plt.rc('animation', html='jshtml')\n",
    "\n",
    "%matplotlib inline"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# root directory\n",
    "DIR_INPUT = \"/kaggle/input/lyft-motion-prediction-autonomous-vehicles\"\n",
    "\n",
    "# set env variable for data\n",
    "os.environ[\"L5KIT_DATA_FOLDER\"] = DIR_INPUT\n",
    "dm = LocalDataManager(None)\n",
    "# print(training_cfg)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Model Definition\n",
    "\n",
    "## LyftModel - simple Resnet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LyftModel(nn.Module):\n",
    "    \n",
    "    def __init__(self, cfg):\n",
    "        super().__init__()\n",
    "        \n",
    "        # set pretrained=True while training\n",
    "        self.backbone = resnet50(pretrained=True) \n",
    "        \n",
    "        num_history_channels = (cfg[\"model_params\"][\"history_num_frames\"] + 1) * 2\n",
    "        num_in_channels = 3 + num_history_channels\n",
    "\n",
    "        self.backbone.conv1 = nn.Conv2d(\n",
    "            # num_in_channels = 25\n",
    "            num_in_channels,\n",
    "            self.backbone.conv1.out_channels,\n",
    "            kernel_size=self.backbone.conv1.kernel_size,\n",
    "            stride=self.backbone.conv1.stride,\n",
    "            padding=self.backbone.conv1.padding,\n",
    "            bias=False,\n",
    "        )\n",
    "        \n",
    "        # This is 512 for resnet18 and resnet34;\n",
    "        # And it is 2048 for the other resnets\n",
    "        backbone_out_features = 2048\n",
    "        \n",
    "        # X, Y coords for the future positions (output shape: Bx50x2)\n",
    "        num_targets = 2 * cfg[\"model_params\"][\"future_num_frames\"]\n",
    "\n",
    "        # You can add more layers here.\n",
    "        self.head = nn.Sequential(\n",
    "            # nn.Dropout(0.2),\n",
    "            nn.Linear(in_features=backbone_out_features, out_features=4096),\n",
    "        )\n",
    "\n",
    "        self.logit = nn.Linear(4096, out_features=num_targets)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        x = self.backbone.conv1(x)\n",
    "        x = self.backbone.bn1(x)\n",
    "        x = self.backbone.relu(x)\n",
    "        x = self.backbone.maxpool(x)\n",
    "\n",
    "        x = self.backbone.layer1(x)\n",
    "        x = self.backbone.layer2(x)\n",
    "        x = self.backbone.layer3(x)\n",
    "        x = self.backbone.layer4(x)\n",
    "\n",
    "        x = self.backbone.avgpool(x)\n",
    "        x = torch.flatten(x, 1)\n",
    "        \n",
    "        x = self.head(x)\n",
    "        x = self.logit(x)\n",
    "        \n",
    "        return x"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LyftMixModel(nn.Module):\n",
    "    \n",
    "    def __init__(self, cfg):\n",
    "        super().__init__()\n",
    "        \n",
    "        # set pretrained=True while training\n",
    "        self.backbone = timm.create_model('mixnet_xl',pretrained=True)\n",
    "        \n",
    "        num_history_channels = (cfg[\"model_params\"][\"history_num_frames\"] + 1) * 2\n",
    "        num_in_channels = 3 + num_history_channels\n",
    "\n",
    "        self.backbone.conv_stem = nn.Conv2d(\n",
    "            num_in_channels,\n",
    "            self.backbone.conv_stem.out_channels,\n",
    "            kernel_size=self.backbone.conv_stem.kernel_size,\n",
    "            stride=self.backbone.conv_stem.stride,\n",
    "            padding=self.backbone.conv_stem.padding,\n",
    "            bias=False,\n",
    "        )\n",
    "        \n",
    "        # This is 512 for resnet18 and resnet34;\n",
    "        # And it is 2048 for the other resnets\n",
    "        backbone_out_features = 1536\n",
    "        \n",
    "        # X, Y coords for the future positions (output shape: Bx50x2)\n",
    "        num_targets = 2 * cfg[\"model_params\"][\"future_num_frames\"]\n",
    "\n",
    "        # You can add more layers here.\n",
    "        self.head = nn.Sequential(\n",
    "            # nn.Dropout(0.2),\n",
    "            nn.Linear(in_features=backbone_out_features, out_features=4096),\n",
    "        )\n",
    "\n",
    "        self.logit = nn.Linear(4096, out_features=num_targets)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        x = self.backbone.conv_stem(x)\n",
    "        x = self.backbone.bn1(x)\n",
    "        x = self.backbone.act1(x)\n",
    "        \n",
    "        x = self.backbone.blocks(x)\n",
    "\n",
    "        x = self.backbone.conv_head(x)\n",
    "        x = self.backbone.bn2(x)\n",
    "        x = self.backbone.act2(x)\n",
    "        x = self.backbone.global_pool(x)\n",
    "        x = torch.flatten(x, 1)\n",
    "        \n",
    "        x = self.head(x)\n",
    "        x = self.logit(x)\n",
    "        \n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Mixnet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LyftMixnet(nn.Module):\n",
    "    \n",
    "    def __init__(self, cfg, classify='mixnet_l'):\n",
    "        super().__init__()\n",
    "        \n",
    "        # set pretrained=True while training\n",
    "        self.backbone = timm.create_model(classify, pretrained=False) \n",
    "        \n",
    "        num_history_channels = (cfg[\"model_params\"][\"history_num_frames\"] + 1) * 2\n",
    "        num_in_channels = 3 + num_history_channels\n",
    "        \n",
    "        self.backbone.conv_stem = nn.Conv2d(\n",
    "            # num_in_channels = 25\n",
    "            num_in_channels,\n",
    "            self.backbone.conv_stem.out_channels,\n",
    "            kernel_size=self.backbone.conv_stem.kernel_size,\n",
    "            stride=self.backbone.conv_stem.stride,\n",
    "            padding=self.backbone.conv_stem.padding,\n",
    "            bias=False,\n",
    "        )\n",
    "        \n",
    "        # X, Y coords for the future positions (output shape: Bx50x2)\n",
    "        num_targets = 2 * cfg[\"model_params\"][\"future_num_frames\"]\n",
    "        self.logit = nn.Linear(self.backbone.classifier.out_features, out_features=num_targets)\n",
    "        \n",
    "    def forward(self, x):\n",
    "        x = self.backbone(x)\n",
    "        x = self.logit(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## MobileNet v2"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "class LyftMobile(nn.Module):\n",
    "    \n",
    "    def __init__(self, cfg):\n",
    "        super().__init__()\n",
    "        \n",
    "        # set pretrained=True while training\n",
    "        self.backbone = mobilenet_v2(pretrained=True) \n",
    "        \n",
    "        num_history_channels = (cfg[\"model_params\"][\"history_num_frames\"] + 1) * 2\n",
    "        num_in_channels = 3 + num_history_channels\n",
    "\n",
    "        self.backbone.features[0][0] = nn.Conv2d(\n",
    "            # num_in_channels = 25\n",
    "            num_in_channels,\n",
    "            self.backbone.features[0][0].out_channels,\n",
    "            kernel_size=self.backbone.features[0][0].kernel_size,\n",
    "            stride=self.backbone.features[0][0].stride,\n",
    "            padding=self.backbone.features[0][0].padding,\n",
    "            bias=False,\n",
    "        )\n",
    "                \n",
    "        # X, Y coords for the future positions (output shape: Bx50x2)\n",
    "        num_targets = 2 * cfg[\"model_params\"][\"future_num_frames\"]\n",
    "\n",
    "        # Fully connected layer.\n",
    "        self.backbone.classifier[1] = nn.Linear(\n",
    "            in_features=self.backbone.classifier[1].in_features,\n",
    "            out_features=num_targets\n",
    "        )\n",
    "\n",
    "        \n",
    "    def forward(self, x):\n",
    "        x = self.backbone(x)\n",
    "        return x"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Multi-Modal"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# --- Function utils ---\n",
    "# Original code from https://github.com/lyft/l5kit/blob/20ab033c01610d711c3d36e1963ecec86e8b85b6/l5kit/l5kit/evaluation/metrics.py\n",
    "\n",
    "def pytorch_neg_multi_log_likelihood_batch(\n",
    "    gt: Tensor, pred: Tensor, confidences: Tensor, avails: Tensor\n",
    ") -> Tensor:\n",
    "    \"\"\"\n",
    "    Compute a negative log-likelihood for the multi-modal scenario.\n",
    "    log-sum-exp trick is used here to avoid underflow and overflow, For more information about it see:\n",
    "    https://en.wikipedia.org/wiki/LogSumExp#log-sum-exp_trick_for_log-domain_calculations\n",
    "    https://timvieira.github.io/blog/post/2014/02/11/exp-normalize-trick/\n",
    "    https://leimao.github.io/blog/LogSumExp/\n",
    "    Args:\n",
    "        gt (Tensor): array of shape (bs)x(time)x(2D coords)\n",
    "        pred (Tensor): array of shape (bs)x(modes)x(time)x(2D coords)\n",
    "        confidences (Tensor): array of shape (bs)x(modes) with a confidence for each mode in each sample\n",
    "        avails (Tensor): array of shape (bs)x(time) with the availability for each gt timestep\n",
    "    Returns:\n",
    "        Tensor: negative log-likelihood for this example, a single float number\n",
    "    \"\"\"\n",
    "    assert len(pred.shape) == 4, f\"expected 3D (MxTxC) array for pred, got {pred.shape}\"\n",
    "    batch_size, num_modes, future_len, num_coords = pred.shape\n",
    "\n",
    "    assert gt.shape == (batch_size, future_len, num_coords), f\"expected 2D (Time x Coords) array for gt, got {gt.shape}\"\n",
    "    assert confidences.shape == (batch_size, num_modes), f\"expected 1D (Modes) array for gt, got {confidences.shape}\"\n",
    "    if not torch.allclose(torch.sum(confidences, dim=1), confidences.new_ones((batch_size,))):\n",
    "        print(confidences)\n",
    "    assert torch.allclose(torch.sum(confidences, dim=1), confidences.new_ones((batch_size,))), \"confidences should sum to 1\"\n",
    "    assert avails.shape == (batch_size, future_len), f\"expected 1D (Time) array for gt, got {avails.shape}\"\n",
    "    # assert all data are valid\n",
    "    assert torch.isfinite(pred).all(), \"invalid value found in pred\"\n",
    "    assert torch.isfinite(gt).all(), \"invalid value found in gt\"\n",
    "    assert torch.isfinite(confidences).all(), \"invalid value found in confidences\"\n",
    "    assert torch.isfinite(avails).all(), \"invalid value found in avails\"\n",
    "\n",
    "    # convert to (batch_size, num_modes, future_len, num_coords)\n",
    "    gt = torch.unsqueeze(gt, 1)  # add modes\n",
    "    avails = avails[:, None, :, None]  # add modes and cords\n",
    "\n",
    "    # error (batch_size, num_modes, future_len)\n",
    "    error = torch.sum(((gt - pred) * avails) ** 2, dim=-1)  # reduce coords and use availability\n",
    "\n",
    "    with np.errstate(divide=\"ignore\"):  # when confidence is 0 log goes to -inf, but we're fine with it\n",
    "        # error (batch_size, num_modes)\n",
    "        error = torch.log(confidences) - 0.5 * torch.sum(error, dim=-1)  # reduce time\n",
    "\n",
    "    # use max aggregator on modes for numerical stability\n",
    "    # error (batch_size, num_modes)\n",
    "    max_value, _ = error.max(dim=1, keepdim=True)  # error are negative at this point, so max() gives the minimum one\n",
    "    error = -torch.log(torch.sum(torch.exp(error - max_value), dim=-1, keepdim=True)) - max_value  # reduce modes\n",
    "    # print(\"error\", error)\n",
    "    return torch.mean(error)\n",
    "\n",
    "\n",
    "def pytorch_neg_multi_log_likelihood_single(\n",
    "    gt: Tensor, pred: Tensor, avails: Tensor\n",
    ") -> Tensor:\n",
    "    \"\"\"\n",
    "\n",
    "    Args:\n",
    "        gt (Tensor): array of shape (bs)x(time)x(2D coords)\n",
    "        pred (Tensor): array of shape (bs)x(time)x(2D coords)\n",
    "        avails (Tensor): array of shape (bs)x(time) with the availability for each gt timestep\n",
    "    Returns:\n",
    "        Tensor: negative log-likelihood for this example, a single float number\n",
    "    \"\"\"\n",
    "    # pred (bs)x(time)x(2D coords) --> (bs)x(mode=1)x(time)x(2D coords)\n",
    "    # create confidence (bs)x(mode=1)\n",
    "    batch_size, future_len, num_coords = pred.shape\n",
    "    confidences = pred.new_ones((batch_size, 1))\n",
    "    return pytorch_neg_multi_log_likelihood_batch(gt, pred.unsqueeze(1), confidences, avails)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Validate Configuration"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# validate cfg\n",
    "validate_cfg = {\n",
    "    \n",
    "    'format_version': 4,\n",
    "    'model_params': {\n",
    "        'history_num_frames': 10,\n",
    "        'history_step_size': 1,\n",
    "        'history_delta_time': 0.1,\n",
    "        'future_num_frames': 50,\n",
    "        'future_step_size': 1,\n",
    "        'future_delta_time': 0.1\n",
    "    },\n",
    "    \n",
    "    'raster_params': {\n",
    "        'raster_size': [300, 300],\n",
    "        'pixel_size': [0.5, 0.5],\n",
    "        'ego_center': [0.25, 0.5],\n",
    "        'map_type': 'py_semantic',\n",
    "        'satellite_map_key': 'aerial_map/aerial_map.png',\n",
    "        'semantic_map_key': 'semantic_map/semantic_map.pb',\n",
    "        'dataset_meta_key': 'meta.json',\n",
    "        'filter_agents_threshold': 0.5,\n",
    "        'disable_traffic_light_faces': False\n",
    "    },\n",
    "    \n",
    "    'validate_data_loader': {\n",
    "    'key': 'scenes/validate.zarr',\n",
    "    'batch_size': 6,\n",
    "    'shuffle': False,\n",
    "    'num_workers': 4\n",
    "    }\n",
    "\n",
    "}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Hyperparameters"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else \"cpu\")\n",
    "# choose parameter conf\n",
    "conf = 2\n",
    "\n",
    "if conf == 1:\n",
    "    # resnet stuff\n",
    "    model = LyftModel(training_cfg).to(device)\n",
    "    optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=2e-6)\n",
    "    criterion = nn.SmoothL1Loss()\n",
    "    # criterion = nn.MSELoss(reduction=\"none\")\n",
    "\n",
    "if conf == 2:\n",
    "    # mixnet_large\n",
    "\n",
    "    model = LyftMixnet(validate_cfg, 'mixnet_l').to(device)\n",
    "    \n",
    "    WEIGHT_FILE = '../input/resnet-34-pth/model_state_mixnetl_25000_17000_nll.pth'\n",
    "    model_state = torch.load(WEIGHT_FILE, map_location=device)\n",
    "    model.load_state_dict(model_state)\n",
    "    \n",
    "    optimizer = optim.Adam(model.parameters(), lr=1e-4, weight_decay=2e-6)\n",
    "    criterion = pytorch_neg_multi_log_likelihood_single\n",
    "    \n",
    "if conf == 3:\n",
    "    # mixnet_medium\n",
    "    model = LyftMixnet(training_cfg, 'mixnet_m').to(device)\n",
    "    optimizer = optim.Adam(model.parameters(), lr=4e-4, weight_decay=2e-6)\n",
    "    criterion = nn.SmoothL1Loss()\n",
    "\n",
    "if conf == 4:\n",
    "    model = LyftMobile(training_cfg).to(device)\n",
    "    optimizer = optim.Adam(model.parameters(), lr=1e-3, weight_decay=2e-6)\n",
    "    criterion = nn.SmoothL1Loss()\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Read Models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model_paths = [\n",
    "    \"../input/resnet-34-pth/mixnetm_35000.pth\",\n",
    "    \"../input/resnet-34-pth/model_state_mixnet_xl_12000.pth\",\n",
    "    \"../input/resnet-34-pth/model_state_mixnet_xl_nll_8000.pth\",\n",
    "    \"../input/resnet-34-pth/model_state_mixnetl_25000_1000_custom_nll.pth\",\n",
    "    \"../input/resnet-34-pth/model_state_mixnetl_25000_1000_nll.pth\",\n",
    "    \"../input/resnet-34-pth/model_state_mixnetl_25000_9000_nll.pth\",\n",
    "    \"../input/resnet-34-pth/model_state_mixnetl_25000_17000_nll.pth\",\n",
    "    \"../input/resnet-34-pth/model_state_mixnetl_35000.pth\",\n",
    "    \"../input/resnet-34-pth/model_state_mixnetl_25000.pth\",\n",
    "]\n",
    "def load_model(model_paths):\n",
    "    models = []\n",
    "    for i, path in enumerate(model_paths):\n",
    "        if i == 0:\n",
    "            model = LyftMixnet(validate_cfg, 'mixnet_m').to(device)\n",
    "        elif i == 1 or i == 2:\n",
    "            model = LyftMixModel(validate_cfg).to(device)\n",
    "\n",
    "        else:\n",
    "            model = LyftMixnet(validate_cfg, 'mixnet_l').to(device)\n",
    "        model_state = torch.load(path, map_location=device)\n",
    "        model.load_state_dict(model_state)\n",
    "        models.append(model)\n",
    "    return models\n",
    "models = load_model(model_paths)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# validation configuration\n",
    "valid_cfg = validate_cfg[\"validate_data_loader\"]\n",
    "\n",
    "# Rasterizer\n",
    "rasterizer = build_rasterizer(validate_cfg, dm)\n",
    "\n",
    "# Validation dataset/dataloader\n",
    "valid_zarr = ChunkedDataset(dm.require(valid_cfg[\"key\"])).open()\n",
    "valid_dataset = AgentDataset(validate_cfg, valid_zarr, rasterizer)\n",
    "whole_size = valid_dataset.__len__()\n",
    "valid_dataset_use, valid_dataset_valid, _ = torch.utils.data.random_split(valid_dataset, [7000, 2000, whole_size-9000], generator=torch.Generator().manual_seed(42))\n",
    "\n",
    "valid_dataloader = DataLoader(valid_dataset_use,\n",
    "                             shuffle=valid_cfg[\"shuffle\"],\n",
    "                             batch_size=valid_cfg[\"batch_size\"],\n",
    "                             num_workers=valid_cfg[\"num_workers\"])\n",
    "\n",
    "valid_dataloader_valid = DataLoader(valid_dataset_valid,\n",
    "                             shuffle=valid_cfg[\"shuffle\"],\n",
    "                             batch_size=valid_cfg[\"batch_size\"],\n",
    "                             num_workers=valid_cfg[\"num_workers\"])\n",
    "\n",
    "print(valid_dataloader.dataset.__len__())\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "metadata": {},
   "outputs": [],
   "source": [
    "def kmeans_ensemble(models:List[Callable], data:Dict, num_cluster=3) -> np.ndarray:\n",
    "    \"\"\"\n",
    "\n",
    "    Args:\n",
    "        models (Iterable): list of models\n",
    "        data (Dict): data from dataloader\n",
    "    Returns:\n",
    "        np.ndarray: array of shape (modes)x(timesteps)x(2D coords), predicted tractories modeled by clustering\n",
    "    \"\"\"\n",
    "    assert len(models) >= num_cluster\n",
    "    \n",
    "    inputs = data[\"image\"].to(device)\n",
    "    target_availabilities = data[\"target_availabilities\"].unsqueeze(-1).to(device)\n",
    "    targets = data[\"target_positions\"].to(device)\n",
    "    clusters = []\n",
    "    \n",
    "    batch_size = targets.shape[0]\n",
    "    num_model = len(models)\n",
    "    \n",
    "    for model in models:\n",
    "        outputs = model(inputs)\n",
    "        clusters.append(outputs.cpu().numpy().copy())\n",
    "    \n",
    "    num_outputs = outputs.shape[1] # 100\n",
    "\n",
    "    arr = np.empty((num_model, num_outputs) )\n",
    "    batch_centers = np.empty((batch_size, num_cluster, num_outputs//2, 2))\n",
    "    for i in range(batch_size):\n",
    "        for j in range(num_model):\n",
    "            \n",
    "            assert clusters[j][i].shape == arr[j].shape\n",
    "            arr[j] = clusters[j][i]\n",
    "            \n",
    "        km = KMeans(n_clusters=num_cluster).fit(arr)\n",
    "        centers = km.cluster_centers_\n",
    "        centers = centers.reshape(num_cluster,50,2)\n",
    "        batch_centers[i] = centers\n",
    "    return batch_centers\n",
    "\n",
    "    "
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "metadata": {},
   "outputs": [],
   "source": [
    "def multiple(models:List[Callable], data:Dict) -> np.ndarray:\n",
    "    \"\"\"\n",
    "\n",
    "    Args:\n",
    "        models (Iterable): list of models\n",
    "        data (Dict): data from dataloader\n",
    "    Returns:\n",
    "        np.ndarray: array of shape (modes)x(timesteps)x(2D coords), predicted tractories modeled by clustering\n",
    "    \"\"\"\n",
    "    assert len(models) > 0\n",
    "    \n",
    "    inputs = data[\"image\"].to(device)\n",
    "    target_availabilities = data[\"target_availabilities\"].unsqueeze(-1).to(device)\n",
    "    targets = data[\"target_positions\"].to(device)\n",
    "    res = []\n",
    "    \n",
    "    batch_size = targets.shape[0]\n",
    "    num_model = len(models)\n",
    "    \n",
    "    for i, model in enumerate(models):\n",
    "        outputs = model(inputs)\n",
    "        res.append(outputs.cpu().numpy().copy().reshape((batch_size, 50, 2)))\n",
    "        \n",
    "        \n",
    "    batch_output = np.empty((batch_size, num_model, 50, 2))\n",
    "    \n",
    "    for i in range(batch_size):\n",
    "        for j in range(num_model):\n",
    "            \n",
    "            batch_output[i][j] = res[j][i]\n",
    "            \n",
    "    return batch_output\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Final Evaluation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1167/1167 [08:19<00:00,  2.34it/s]\n"
     ]
    }
   ],
   "source": [
    "# Final - Evaluate Validation Dataset \n",
    "model.eval()\n",
    "torch.set_grad_enabled(False)\n",
    "\n",
    "# store information for evaluation\n",
    "future_coords_offsets_pd = []\n",
    "timestamps = []\n",
    "# coordinates ground truth\n",
    "valid_coords_gts = []\n",
    "# target avalabilities\n",
    "target_avail_pd = []\n",
    "agent_ids = []\n",
    "progress_bar = tqdm(valid_dataloader)\n",
    "for data in progress_bar:\n",
    "    \n",
    "    inputs = data[\"image\"].to(device)\n",
    "    target_availabilities = data[\"target_availabilities\"].unsqueeze(-1).to(device)\n",
    "    targets = data[\"target_positions\"].to(device)\n",
    "    \n",
    "    outputs = model(inputs).reshape(targets.shape)\n",
    "    \n",
    "    future_coords_offsets_pd.append(outputs.cpu().numpy().copy())\n",
    "    timestamps.append(data[\"timestamp\"].numpy().copy())\n",
    "    agent_ids.append(data[\"track_id\"].numpy().copy())\n",
    "    valid_coords_gts.append(data[\"target_positions\"].numpy().copy())\n",
    "    target_avail_pd.append(target_availabilities.cpu().numpy().copy())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Final Evaluation -- Ensemble"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 40,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 17%|█▋        | 198/1167 [01:55<09:23,  1.72it/s]\n"
     ]
    },
    {
     "ename": "KeyboardInterrupt",
     "evalue": "",
     "output_type": "error",
     "traceback": [
      "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m                         Traceback (most recent call last)",
      "\u001b[0;32m<ipython-input-40-9db18da1521c>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n\u001b[1;32m     15\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mdata\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mprogress_bar\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     16\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 17\u001b[0;31m     \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmultiple\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodels\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mdata\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     18\u001b[0m     \u001b[0mfuture_coords_offsets_pd\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcopy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     19\u001b[0m     \u001b[0mtimestamps\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mdata\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m\"timestamp\"\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcopy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-34-75ce4e7abb51>\u001b[0m in \u001b[0;36mmultiple\u001b[0;34m(models, data)\u001b[0m\n\u001b[1;32m     19\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     20\u001b[0m     \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmodel\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodels\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 21\u001b[0;31m         \u001b[0moutputs\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minputs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     22\u001b[0m         \u001b[0mres\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0moutputs\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcpu\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mnumpy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcopy\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m50\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     23\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    725\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    726\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    728\u001b[0m         for hook in itertools.chain(\n\u001b[1;32m    729\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m<ipython-input-6-17bb4aaa2bbb>\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m     25\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     26\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 27\u001b[0;31m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbackbone\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     28\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mlogit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     29\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    725\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    726\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    728\u001b[0m         for hook in itertools.chain(\n\u001b[1;32m    729\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/timm/models/efficientnet.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m    389\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    390\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 391\u001b[0;31m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward_features\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    392\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mglobal_pool\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    393\u001b[0m         \u001b[0;32mif\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mdrop_rate\u001b[0m \u001b[0;34m>\u001b[0m \u001b[0;36m0.\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/timm/models/efficientnet.py\u001b[0m in \u001b[0;36mforward_features\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m    382\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbn1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    383\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mact1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 384\u001b[0;31m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mblocks\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    385\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv_head\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    386\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbn2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    725\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    726\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    728\u001b[0m         for hook in itertools.chain(\n\u001b[1;32m    729\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/nn/modules/container.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m    115\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    116\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 117\u001b[0;31m             \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    118\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    119\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    725\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    726\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    728\u001b[0m         for hook in itertools.chain(\n\u001b[1;32m    729\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/nn/modules/container.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m    115\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    116\u001b[0m         \u001b[0;32mfor\u001b[0m \u001b[0mmodule\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 117\u001b[0;31m             \u001b[0minput\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodule\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    118\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    119\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    725\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    726\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    728\u001b[0m         for hook in itertools.chain(\n\u001b[1;32m    729\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/timm/models/efficientnet_blocks.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m    255\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    256\u001b[0m         \u001b[0;31m# Point-wise expansion\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 257\u001b[0;31m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconv_pw\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    258\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mbn1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    259\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mact1\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    725\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    726\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    728\u001b[0m         for hook in itertools.chain(\n\u001b[1;32m    729\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/timm/models/layers/mixed_conv2d.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, x)\u001b[0m\n\u001b[1;32m     47\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     48\u001b[0m         \u001b[0mx_split\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 49\u001b[0;31m         \u001b[0mx_out\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_split\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mc\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     50\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_out\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     51\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/timm/models/layers/mixed_conv2d.py\u001b[0m in \u001b[0;36m<listcomp>\u001b[0;34m(.0)\u001b[0m\n\u001b[1;32m     47\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     48\u001b[0m         \u001b[0mx_split\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplits\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 49\u001b[0;31m         \u001b[0mx_out\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0mc\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_split\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mc\u001b[0m \u001b[0;32min\u001b[0m \u001b[0menumerate\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m     50\u001b[0m         \u001b[0mx\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcat\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx_out\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m1\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m     51\u001b[0m         \u001b[0;32mreturn\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/nn/modules/module.py\u001b[0m in \u001b[0;36m_call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m    725\u001b[0m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_slow_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    726\u001b[0m         \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 727\u001b[0;31m             \u001b[0mresult\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m*\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    728\u001b[0m         for hook in itertools.chain(\n\u001b[1;32m    729\u001b[0m                 \u001b[0m_global_forward_hooks\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/nn/modules/conv.py\u001b[0m in \u001b[0;36mforward\u001b[0;34m(self, input)\u001b[0m\n\u001b[1;32m    421\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    422\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 423\u001b[0;31m         \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_conv_forward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mweight\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m    424\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    425\u001b[0m \u001b[0;32mclass\u001b[0m \u001b[0mConv3d\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0m_ConvNd\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;32m/opt/conda/lib/python3.7/site-packages/torch/nn/modules/conv.py\u001b[0m in \u001b[0;36m_conv_forward\u001b[0;34m(self, input, weight)\u001b[0m\n\u001b[1;32m    418\u001b[0m                             _pair(0), self.dilation, self.groups)\n\u001b[1;32m    419\u001b[0m         return F.conv2d(input, weight, self.bias, self.stride,\n\u001b[0;32m--> 420\u001b[0;31m                         self.padding, self.dilation, self.groups)\n\u001b[0m\u001b[1;32m    421\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m    422\u001b[0m     \u001b[0;32mdef\u001b[0m \u001b[0mforward\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0minput\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;34m->\u001b[0m \u001b[0mTensor\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
      "\u001b[0;31mKeyboardInterrupt\u001b[0m: "
     ]
    }
   ],
   "source": [
    "# Final - Evaluate Validation Dataset \n",
    "for model in models:\n",
    "    model.eval()\n",
    "torch.set_grad_enabled(False)\n",
    "\n",
    "# store information for evaluation\n",
    "future_coords_offsets_pd = []\n",
    "timestamps = []\n",
    "# coordinates ground truth\n",
    "valid_coords_gts = []\n",
    "# target availabilities\n",
    "target_avail_pd = []\n",
    "agent_ids = []\n",
    "progress_bar = tqdm(valid_dataloader)\n",
    "for data in progress_bar:\n",
    "    \n",
    "    outputs = multiple(models, data)\n",
    "    future_coords_offsets_pd.append(outputs.copy())\n",
    "    timestamps.append(data[\"timestamp\"].numpy().copy())\n",
    "    agent_ids.append(data[\"track_id\"].numpy().copy())\n",
    "    valid_coords_gts.append(data[\"target_positions\"].numpy().copy())\n",
    "    target_avail_pd.append(data[\"target_availabilities\"].unsqueeze(-1).numpy().copy())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Final Evaluation -- Kmeans"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 1167/1167 [13:42<00:00,  1.42it/s]\n"
     ]
    }
   ],
   "source": [
    "# Final - Evaluate Validation Dataset \n",
    "for model in models:\n",
    "    model.eval()\n",
    "torch.set_grad_enabled(False)\n",
    "\n",
    "# store information for evaluation\n",
    "future_coords_offsets_pd = []\n",
    "timestamps = []\n",
    "# coordinates ground truth\n",
    "valid_coords_gts = []\n",
    "# target availabilities\n",
    "target_avail_pd = []\n",
    "agent_ids = []\n",
    "progress_bar = tqdm(valid_dataloader)\n",
    "for data in progress_bar:\n",
    "    \n",
    "    outputs = kmeans_ensemble(models, data)\n",
    "    future_coords_offsets_pd.append(outputs.copy())\n",
    "    timestamps.append(data[\"timestamp\"].numpy().copy())\n",
    "    agent_ids.append(data[\"track_id\"].numpy().copy())\n",
    "    valid_coords_gts.append(data[\"target_positions\"].numpy().copy())\n",
    "    target_avail_pd.append(data[\"target_availabilities\"].unsqueeze(-1).numpy().copy())"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Concatenate Data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 21,
   "metadata": {
    "_kg_hide-output": true
   },
   "outputs": [],
   "source": [
    "timestamps_concat = np.concatenate(timestamps)\n",
    "track_ids_concat = np.concatenate(agent_ids)\n",
    "coords_concat = np.concatenate(future_coords_offsets_pd)\n",
    "gt_valid_final = np.concatenate(valid_coords_gts)\n",
    "target_avail_concat = np.concatenate(target_avail_pd)\n",
    "\n",
    "# gt_valid_2D = np.reshape(gt_valid_final, (70000, 100))\n",
    "\n",
    "# log_like = neg_multi_log_likelihood (\n",
    "#     ground_truth=gt_valid_2D,\n",
    "#     pred=coords_concat,\n",
    "#     confidences=np.array([1,0,0]),\n",
    "#     avails=target_avail_concat\n",
    "# )\n",
    "\n",
    "# print(\"Negative multi Likelihood is:\", log_like)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "8"
      ]
     },
     "execution_count": 43,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "len(models)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Compute metrics"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 42,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "neg_multi_log_likelihood 34.93464970678191\n",
      "time_displace [0.03780884 0.06774837 0.09400525 0.11911281 0.14333772 0.16555054\n",
      " 0.18603317 0.20607193 0.22537911 0.2399516  0.25563167 0.27254414\n",
      " 0.28808102 0.30089175 0.31353564 0.32429114 0.33827663 0.34997946\n",
      " 0.36289177 0.37681631 0.38824709 0.40050204 0.41532554 0.42659337\n",
      " 0.43842456 0.45155713 0.46308229 0.47460614 0.48670486 0.50179859\n",
      " 0.51000651 0.52614248 0.54252046 0.55251546 0.56777636 0.57811409\n",
      " 0.58324775 0.59759121 0.61509243 0.62943898 0.64193703 0.65862458\n",
      " 0.6745409  0.69017376 0.70823359 0.72041866 0.73800779 0.75194556\n",
      " 0.76496447 0.79003944]\n"
     ]
    }
   ],
   "source": [
    "# Negative Log Likelihood Metrics\n",
    "eval_gt_path = \"valid_gt_7000.csv\"\n",
    "pred_path = 'submission_mixnetl_ensemble.csv'\n",
    "\n",
    "# generate ground truth csv\n",
    "write_gt_csv(\n",
    "    csv_path=eval_gt_path, \n",
    "    timestamps=timestamps_concat, \n",
    "    track_ids=track_ids_concat, \n",
    "    coords=gt_valid_final, \n",
    "    avails=target_avail_concat.squeeze(-1)\n",
    ")\n",
    "\n",
    "num_examples = gt_valid_final.shape[0]\n",
    "confidence = np.array([0.33,0.33,0.34])\n",
    "confidences= np.empty((num_examples, 3))\n",
    "for i in range(num_examples):\n",
    "    confidences[i] = confidence\n",
    "    \n",
    "    \n",
    "# submission.csv\n",
    "write_pred_csv(pred_path,\n",
    "               timestamps=timestamps_concat,\n",
    "               track_ids=track_ids_concat,\n",
    "               coords=coords_concat,\n",
    "               confs=confidences\n",
    "              )\n",
    "\n",
    "metrics = compute_metrics_csv(eval_gt_path, pred_path, [neg_multi_log_likelihood, time_displace])\n",
    "for metric_name, metric_mean in metrics.items():\n",
    "    print(metric_name, metric_mean)\n",
    "    \n",
    "# Save Metric\n",
    "metric = 'metric_mixnetl_kmeans.npy'\n",
    "np.save(metric,metrics)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Metric Single_Model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 15,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "neg_multi_log_likelihood 53.71172870034174\n",
      "time_displace [0.03773215 0.06684401 0.09530268 0.12201825 0.14818366 0.17338283\n",
      " 0.19758767 0.22078182 0.24258761 0.26052195 0.28114266 0.30400775\n",
      " 0.32426746 0.34373423 0.36105989 0.37776627 0.3983554  0.41617274\n",
      " 0.43362092 0.45443441 0.47207964 0.49163032 0.51341603 0.53085735\n",
      " 0.5490904  0.5684535  0.58860053 0.60813105 0.62938755 0.65290465\n",
      " 0.66944903 0.69292696 0.71680625 0.73534621 0.75950115 0.77799942\n",
      " 0.78870733 0.80942547 0.83417625 0.85535263 0.8724096  0.89552867\n",
      " 0.91580872 0.93774403 0.96052092 0.97349923 0.99529021 1.00818272\n",
      " 1.02306923 1.04947916]\n"
     ]
    }
   ],
   "source": [
    "# Negative Log Likelihood Metrics\n",
    "eval_gt_path = \"valid_gt_7000.csv\"\n",
    "pred_path = 'submission_mixnetl_42000.csv'\n",
    "\n",
    "# generate ground truth csv\n",
    "write_gt_csv(\n",
    "    csv_path=eval_gt_path, \n",
    "    timestamps=timestamps_concat, \n",
    "    track_ids=track_ids_concat, \n",
    "    coords=gt_valid_final, \n",
    "    avails=target_avail_concat.squeeze(-1)\n",
    ")\n",
    "\n",
    "    \n",
    "    \n",
    "# submission.csv\n",
    "write_pred_csv(pred_path,\n",
    "               timestamps=timestamps_concat,\n",
    "               track_ids=track_ids_concat,\n",
    "               coords=coords_concat,\n",
    "               # confs=confidences\n",
    "              )\n",
    "\n",
    "metrics = compute_metrics_csv(eval_gt_path, pred_path, [neg_multi_log_likelihood, time_displace])\n",
    "for metric_name, metric_mean in metrics.items():\n",
    "    print(metric_name, metric_mean)\n",
    "    \n",
    "# Save Metric\n",
    "metric = 'metric_mixnetl_42000.npy'\n",
    "np.save(metric,metrics)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 23,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/html": [
       "<a href='metric_mixnetl_25000_kmeans.npy' target='_blank'>metric_mixnetl_25000_kmeans.npy</a><br>"
      ],
      "text/plain": [
       "/kaggle/working/metric_mixnetl_25000_kmeans.npy"
      ]
     },
     "execution_count": 23,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "os.chdir(r'/kaggle/working')\n",
    "FileLink(metric)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "FileLink(eval_gt_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "FileLink(pred_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Plot Prediction Trajectories"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.eval()\n",
    "torch.set_grad_enabled(False)\n",
    "\n",
    "# Uncomment to choose satelliter or semantic rasterizer\n",
    "# validate_cfg[\"raster_params\"][\"map_type\"] = \"py_satellite\"\n",
    "validate_cfg[\"raster_params\"][\"map_type\"] = \"py_semantic\"\n",
    "\n",
    "rast = build_rasterizer(validate_cfg, dm)\n",
    "\n",
    "eval_ego_dataset = EgoDataset(validate_cfg, valid_dataset.dataset, rast)\n",
    "num_frames = 2 # randomly pick _ frames\n",
    "random_frames = np.random.randint(0,len(eval_ego_dataset)-1, (num_frames,))\n",
    "\n",
    "for frame_number in random_frames:  \n",
    "    agent_indices = valid_dataset.get_frame_indices(frame_number) \n",
    "    if not len(agent_indices):\n",
    "        continue\n",
    "\n",
    "    # get AV point-of-view frame\n",
    "    data_ego = eval_ego_dataset[frame_number]\n",
    "    im_ego = rasterizer.to_rgb(data_ego[\"image\"].transpose(1, 2, 0))\n",
    "    center = np.asarray(validate_cfg[\"raster_params\"][\"ego_center\"]) * validate_cfg[\"raster_params\"][\"raster_size\"]\n",
    "    \n",
    "    predicted_positions = []\n",
    "    target_positions = []\n",
    "\n",
    "    for v_index in agent_indices:\n",
    "        data_agent = valid_dataset[v_index]\n",
    "\n",
    "        out_net = model(torch.from_numpy(data_agent[\"image\"]).unsqueeze(0).to(device))\n",
    "        out_pos = out_net[0].reshape(-1, 2).detach().cpu().numpy()\n",
    "        # store absolute world coordinates\n",
    "        predicted_positions.append(transform_points(out_pos, data_agent[\"world_from_agent\"]))\n",
    "        # retrieve target positions from the GT and store as absolute coordinates\n",
    "        track_id, timestamp = data_agent[\"track_id\"], data_agent[\"timestamp\"]\n",
    "        target_positions.append(transform_points(data_agent[\"target_positions\"], data_agent[\"world_from_agent\"]) )\n",
    "\n",
    "    # convert coordinates to AV point-of-view so we can draw them\n",
    "    predicted_positions = transform_points(np.concatenate(predicted_positions), data_ego[\"raster_from_world\"])\n",
    "    target_positions = transform_points(np.concatenate(target_positions), data_ego[\"raster_from_world\"])\n",
    "    \n",
    "    # make sure ground truth and prediction have the same data size\n",
    "    assert len(target_positions) == len(predicted_positions)\n",
    "    \n",
    "    # draw_trajectory(im_ego, predicted_positions, PREDICTED_POINTS_COLOR)\n",
    "    draw_trajectory(im_ego, target_positions, TARGET_POINTS_COLOR)\n",
    "    \n",
    "    plt.rcParams['figure.figsize'] = 6, 6\n",
    "    plt.imshow(im_ego[::-1])\n",
    "    plt.show()\n",
    "    \n",
    "    draw_trajectory(im_ego, predicted_positions, PREDICTED_POINTS_COLOR)\n",
    "    \n",
    "    plt.rcParams['figure.figsize'] = 6, 6\n",
    "    plt.imshow(im_ego[::-1])\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Plot Ensemble Predictions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "model.eval()\n",
    "torch.set_grad_enabled(False)\n",
    "\n",
    "# Uncomment to choose satelliter or semantic rasterizer\n",
    "# validate_cfg[\"raster_params\"][\"map_type\"] = \"py_satellite\"\n",
    "validate_cfg[\"raster_params\"][\"map_type\"] = \"py_semantic\"\n",
    "\n",
    "rast = build_rasterizer(validate_cfg, dm)\n",
    "\n",
    "eval_ego_dataset = EgoDataset(validate_cfg, valid_dataset.dataset, rast)\n",
    "num_frames = 1 # randomly pick _ frames\n",
    "random_frames = np.random.randint(0,len(eval_ego_dataset)-1, (num_frames,))\n",
    "\n",
    "for frame_number in random_frames:  \n",
    "    agent_indices = valid_dataset.get_frame_indices(frame_number) \n",
    "    if not len(agent_indices):\n",
    "        continue\n",
    "\n",
    "    # get AV point-of-view frame\n",
    "    data_ego = eval_ego_dataset[frame_number]\n",
    "    im_ego = rasterizer.to_rgb(data_ego[\"image\"].transpose(1, 2, 0))\n",
    "    center = np.asarray(validate_cfg[\"raster_params\"][\"ego_center\"]) * validate_cfg[\"raster_params\"][\"raster_size\"]\n",
    "    \n",
    "    target_positions = []\n",
    "    \n",
    "    for model in models:\n",
    "        \n",
    "        predicted_positions = []\n",
    "        for v_index in agent_indices:\n",
    "            data_agent = valid_dataset[v_index]\n",
    "\n",
    "            out_net = model(torch.from_numpy(data_agent[\"image\"]).unsqueeze(0).to(device))\n",
    "            out_pos = out_net[0].reshape(-1, 2).detach().cpu().numpy()\n",
    "            # store absolute world coordinates\n",
    "            predicted_positions.append(transform_points(out_pos, data_agent[\"world_from_agent\"]))\n",
    "            # retrieve target positions from the GT and store as absolute coordinates\n",
    "            track_id, timestamp = data_agent[\"track_id\"], data_agent[\"timestamp\"]\n",
    "\n",
    "        # convert coordinates to AV point-of-view so we can draw them\n",
    "        predicted_positions = transform_points(np.concatenate(predicted_positions), data_ego[\"raster_from_world\"])\n",
    "\n",
    "        # make sure ground truth and prediction have the same data size\n",
    "        # different predicted points\n",
    "        draw_trajectory(im_ego, predicted_positions, PREDICTED_POINTS_COLOR)\n",
    "        \n",
    "    target_positions.append(transform_points(data_agent[\"target_positions\"], data_agent[\"world_from_agent\"]) )\n",
    "    target_positions = transform_points(np.concatenate(target_positions), data_ego[\"raster_from_world\"])\n",
    "    assert len(target_positions) == len(predicted_positions)\n",
    "    draw_trajectory(im_ego, target_positions, TARGET_POINTS_COLOR)    \n",
    "    \n",
    "    plt.rcParams['figure.figsize'] = 6, 6\n",
    "    print(im_ego.shape)\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from IPython.display import display, clear_output\n",
    "import PIL\n",
    "from IPython.display import Image\n",
    "\n",
    "num_frames = 1 # randomly pick _ frames\n",
    "random_frames = np.random.randint(0,10000, (num_frames,))\n",
    "\n",
    "validate_cfg[\"raster_params\"][\"map_type\"] = \"py_semantic\"\n",
    "rast = build_rasterizer(validate_cfg, dm)\n",
    "\n",
    "for scene_idx in random_frames:\n",
    "    indexes = eval_ego_dataset.get_scene_indices(scene_idx)\n",
    "    images = []\n",
    "\n",
    "    for idx in indexes:\n",
    "\n",
    "        data = eval_ego_dataset[idx]\n",
    "        im = data[\"image\"].transpose(1, 2, 0)\n",
    "        im = eval_ego_dataset.rasterizer.to_rgb(im)\n",
    "        target_positions_pixels = transform_points(data_agent[\"target_positions\"], data_agent[\"raster_from_agent\"])\n",
    "        center_in_pixels = np.asarray(validate_cfg[\"raster_params\"][\"ego_center\"]) * validate_cfg[\"raster_params\"][\"raster_size\"]\n",
    "        # draw_trajectory(im, target_positions_pixels, TARGET_POINTS_COLOR)\n",
    "        clear_output(wait=True)\n",
    "        images.append(im)\n",
    "\n",
    "    write_gif(\n",
    "    output_filepath=\"output_scene.gif\",\n",
    "    images=images,\n",
    "    resolution=(300,300),\n",
    "    )\n",
    "    \n",
    "Image(\"output_scene.gif\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.10"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
